# This script does the final pre-processing of the Spark output so that it is most useful for plotting.

library(data.table)
library(tidyverse)
library(dtplyr)
library(parallel)
library(Hmisc)


### Helper Functions ------------------------------------------------

# Compute a domain's audience score: the share-weighted mean of a 101-point
# score grid, pooled over every URL of the domain.
#
# The share columns n_000 ... n_100 are summed over all rows of `urls_df`
# whose `domain` equals `d`. Those totals weight the grid
# seq(1, -1, length.out = 101), so n_000 is weighted as +1 and n_100 as -1.
# NOTE(review): confirm this orientation matches how the n_* columns are
# produced upstream.
#
# urls_df: data frame with a `domain` column and share columns n_000:n_100.
# d:       a single domain name.
# Returns the weighted mean as a single number (via Hmisc::wtd.mean).
domain_sample_mean <- function(urls_df, d) {
    # `.env$d` pins the comparison to the argument, so a data column that
    # happens to be called `d` cannot shadow it.
    share_totals <- urls_df |>
        filter(domain == .env$d) |>
        select(n_000:n_100) |>
        colSums()
    score_grid <- seq(1, -1, length.out = 101)
    Hmisc::wtd.mean(score_grid, share_totals)
}

# Compute the spread of a domain's audience: the square root of the
# share-weighted variance of a 101-point score grid, pooled over every URL
# of the domain.
#
# The share columns n_000 ... n_100 are summed over all rows of `urls_df`
# whose `domain` equals `d`. Those totals weight the grid
# seq(1, -1, length.out = 101), so n_000 is weighted as +1 and n_100 as -1.
# NOTE(review): confirm this orientation matches how the n_* columns are
# produced upstream.
#
# urls_df: data frame with a `domain` column and share columns n_000:n_100.
# d:       a single domain name.
# Returns sqrt(Hmisc::wtd.var(...)) with wtd.var's default arguments.
domain_sample_sd <- function(urls_df, d) {
    # `.env$d` pins the comparison to the argument, so a data column that
    # happens to be called `d` cannot shadow it.
    share_totals <- urls_df |>
        filter(domain == .env$d) |>
        select(n_000:n_100) |>
        colSums()
    score_grid <- seq(1, -1, length.out = 101)
    sqrt(Hmisc::wtd.var(score_grid, share_totals))
}


# Given a data frame `u` of URLs, keep the rows for domain `source` and
# classify each URL's score against a 99% confidence interval around the
# domain score, plus a wider "substantive" (TOST-style) interval.
#
# For a URL with share count n the statistical interval is
#     domain_score +/- z * domain_sd / sqrt(n)
# (the default z = 2.57 approximates the two-sided 99% normal quantile).
# The substantive interval widens each bound by the margin `m`.
# Any bound beyond the [-1, 1] score scale is stored as NA. When classifying,
# an NA bound is treated as the scale edge (+1 or -1).
#
# Arguments:
#   u                 data frame containing `domain` and the columns named by
#                     the *_var arguments.
#   source            domain whose rows are kept.
#   m                 substantive margin added outside the statistical bounds.
#   url_count_var     column with each URL's share count (sets the CI width).
#   domain_count_var  not used by the computation; kept so existing calls
#                     remain valid.
#   url_score_var     column with each URL's score.
#   domain_score_var  column with the domain score (first row is used).
#   domain_sd_var     column with the domain sd (first row is used).
#   z                 critical value for the statistical interval.
#
# Returns the `source` rows as a tibble without the `domain` column. It adds
# the columns ci.upper, ci.lower, ci.upper.tost and ci.lower.tost; the integer
# 0/1 flags stat_sig, sub_sig_left and sub_sig_right (NA when the URL score is
# NA); and the factor sig_level.
add_analytic_ci <- function(u, source, m = 0.1,
                            url_count_var = "n_shares", domain_count_var = "domain_shares",
                            url_score_var = "url_score", domain_score_var = "domain_score", domain_sd_var = "domain_sd",
                            z = 2.57) {
    # `.env$source` pins the comparison to the argument rather than to any
    # column that might be named `source`.
    subset_df <- as_tibble(filter(u, domain == .env$source))

    domain_score <- subset_df[[domain_score_var]][1]
    domain_sd <- subset_df[[domain_sd_var]][1]

    # Bounds that fall outside the [-1, 1] score scale are replaced with NA.
    clip_upper <- function(x) replace(x, x > 1, NA)
    clip_lower <- function(x) replace(x, x < -1, NA)

    # `.env$` is used for the locals because the data usually also carries
    # columns named `domain_score` / `domain_sd`, which would otherwise mask
    # them inside mutate().
    result <- subset_df |>
        mutate(
            ci.upper = clip_upper(.env$domain_score + .env$z * .env$domain_sd / sqrt(.data[[url_count_var]])),
            ci.lower = clip_lower(.env$domain_score - .env$z * .env$domain_sd / sqrt(.data[[url_count_var]])),
            # The substantive bounds are derived from the already-clipped
            # statistical bounds, then clipped again.
            ci.upper.tost = clip_upper(ci.upper + .env$m),
            ci.lower.tost = clip_lower(ci.lower - .env$m),
            stat_sig = as.integer(
                .data[[url_score_var]] > coalesce(ci.upper, 1) |
                    .data[[url_score_var]] < coalesce(ci.lower, -1)
            ),
            sub_sig_left = as.integer(.data[[url_score_var]] < coalesce(ci.lower.tost, -1)),
            sub_sig_right = as.integer(.data[[url_score_var]] > coalesce(ci.upper.tost, 1)),
            # The first matching rule wins, so the substantive labels take
            # priority over the purely statistical ones. Rows matching no
            # rule (NA score) stay NA.
            sig_level = as.factor(case_when(
                sub_sig_right == 1 ~ "Substantively Significantly Right",
                sub_sig_left == 1 ~ "Substantively Significantly Left",
                stat_sig == 1 & sub_sig_left == 0 & sub_sig_right == 0 ~
                    "Statistically But Not Substantively Significant",
                stat_sig == 0 ~ "Not Statistically Significant"
            ))
        ) |>
        # group_modify() re-adds the grouping column itself.
        select(-domain)

    result
}


# Load data ------------------------------------------------------------------
urls_df <- data.table::fread("data/raw/simulation_input.tsv")
urls <- data.table::fread("data/raw/url_reference_table.tsv", sep = "\t")

# Load simulation data. Only `simulated_urls` is referenced below.
# NOTE(review): confirm whether vecs.RData / domain_vecs.RData are still needed.
load("data/sim/vecs.RData")
load("data/sim/domain_vecs.RData")
load("data/sim/sim_urls.RData")

# Fail loudly if the reference table has duplicate (url, domain) keys. Such
# duplicates would add rows and break the positional bind_cols() below.
urls_df <- urls_df |>
    left_join(urls, by = c("url", "domain"), relationship = "many-to-one")

# simulated_urls is attached by position, so its rows must be in the same
# order as urls_df. Verify that the domains line up before binding.
if (!identical(as.character(simulated_urls$domain), as.character(urls_df$domain))) {
    stop("simulated_urls is not row-aligned with urls_df (domain mismatch)", call. = FALSE)
}
urls_df <- bind_cols(urls_df, select(simulated_urls, -domain))

# Drop URLs without a share count.
urls_df <- filter(urls_df, !is.na(num_shares))

# Keep domains with more than 10 URLs and more than 1000 total shares.
domains <- urls_df |>
    group_by(domain) |>
    summarise(domain_shares = sum(num_shares), n_urls = n()) |>
    filter(n_urls > 10, domain_shares > 1000)

domains <- domains |>
    mutate(
        domain_score = map_dbl(domain, \(x) domain_sample_mean(urls_df, x)),
        domain_sd = map_dbl(domain, \(x) domain_sample_sd(urls_df, x))
    )

# NOTE: changed left join to inner join after thresholding domains, so URLs
# from domains that did not pass the thresholds are dropped.
urls_df <- inner_join(urls_df, domains, by = join_by(domain))

# TODO: code is a little inconsistent in naming here.
# add_analytic_ci() reads share counts from `n_shares` by default.
urls_df$n_shares <- urls_df$num_shares

urls_df <- urls_df |>
    group_by(domain) |>
    group_modify(\(df, key) add_analytic_ci(df, key$domain), .keep = TRUE) |>
    ungroup()

save(list = c("urls_df", "domains"), file = "data/plotting_data.RData")
